% Process a FISTA-based (L1-regularized) spectrum reconstruction.
% This file is common to all directories.
% ==========================================
% links
% 1) Proximal gradient method explained.
% https://www.youtube.com/watch?v=sy4pRJ3g530
% 2) Relation between the Frobenius and operator (spectral) norms of a matrix.
% https://math.stackexchange.com/questions/252819/why-is-the-frobenius-norm-of-a-matrix-greater-than-or-equal-to-the-spectral-norm
% papers
% 1) S.-J. Kim, K. Koh, M. Lustig, S. Boyd, D. Gorinevsky, "An interior-point method for large-scale l1-regularized least squares", IEEE Journal of Selected Topics in Signal Processing, 1(4), December 2007.
% =====================================================================================

function [ppm1_axis,Res_flip,SNR_fista,SNRt_fista,expt,FID] = process_Fista_spectrum(directory,input_params)
% PROCESS_FISTA_SPECTRUM Reconstruct a spectrum from the raw data in an
% experiment directory using L1-regularized least squares (lasso_fista).
%
% Inputs
%   directory    - experiment folder containing 'acqus', either 'acqu3s' or
%                  'acqu2s', and the raw 'ser' file.
%   input_params - struct with fields lambda_val, T1, T2, check_ssfp,
%                  remove_dc, Nskip, min_sw_Hz, min_npoints, lines_to_skip,
%                  run_km_calib, RF_phase, n_phases, add_band,
%                  sw_display_Hz, signal_bound_ppm, noise_bound_ppm.
%                  Optional fields (default used when absent):
%                    average_ssfp (true) - average the ssfp check over rows.
%                    calc_GtG     (0)    - >0: L = max(eig(G'*G)),
%                                          else L = norm(G, 'fro').
%                    km_values    (-4:7) - km grid for km calibration.
%                    km_default   (1.15) - km used when calibration is off.
%
% Outputs
%   ppm1_axis  - frequency axis, (freq1*1000 + freq_read)/larmor_freq.
%   Res_flip   - reconstructed spectrum, flipped to follow ppm1_axis.
%   SNR_fista, SNRt_fista - outputs of SNR_calc for the reconstruction.
%   expt       - experiment time in seconds: NS*Nshift*TR/1000 (TR in ms).
%   FID        - raw complex samples read from 'ser' (int32 values read as
%                interleaved real/imaginary pairs).

% Import experiment parameters
% =============================
lambda_val = input_params.lambda_val;
T1 = input_params.T1;
T2 = input_params.T2;
check_ssfp = input_params.check_ssfp;
remove_dc = input_params.remove_dc;
Nskip = input_params.Nskip;
min_sw_Hz = input_params.min_sw_Hz;
min_npoints = input_params.min_npoints;
lines_to_skip = input_params.lines_to_skip;
run_km_calib = input_params.run_km_calib;
RF_phase = input_params.RF_phase;
n_phases = input_params.n_phases;
add_band = input_params.add_band;
sw_display_Hz = input_params.sw_display_Hz;
signal_bound_ppm = input_params.signal_bound_ppm;
noise_bound_ppm = input_params.noise_bound_ppm;

% Optional settings; the defaults reproduce the previously hard-coded values.
average_ssfp = local_get_field(input_params, 'average_ssfp', true);
calc_GtG = local_get_field(input_params, 'calc_GtG', 0);
km = local_get_field(input_params, 'km_values', -4:7);
km_default = local_get_field(input_params, 'km_default', 1.15);

% Maximum number of lasso_fista attempts (L is doubled after each failure).
max_fista_tries = 8;

% Reading acqu files
% ===================
% 'acqu3s' selects read_params_from_file; otherwise 'acqu2s' is read with
% read_params_from_file_2.
listing = dir(directory);
has_acqu3s = any(strcmp({listing.name}, 'acqu3s'));

filetext = fileread(fullfile(directory, 'acqus'));
if has_acqu3s
    filetext1 = fileread(fullfile(directory, 'acqu3s'));
    parameters = read_params_from_file(filetext, filetext1);
else
    filetext1 = fileread(fullfile(directory, 'acqu2s'));
    parameters = read_params_from_file_2(filetext, filetext1);
end

% Fields used from the struct "parameters".
TR = parameters.TR;                     % ms (see expt below).
FA = parameters.FA;
sw = parameters.sw;
npoints = parameters.npoints;
freq_read = parameters.freq_read;
larmor_freq = parameters.larmor_freq;
Nshift = parameters.Nshift;
NS = parameters.NS;

% npoints and offres
% ================================
npoints = npoints/2;       % each point in 'npoints' is complex.
offres = freq_read;

band_num = Nshift + add_band;     % number of bands.

% opts_km definitions for km lasso_fista.
% =============================================
opts_km.lambda = lambda_val;
opts_km.tol = 1e-3;
opts_km.verbose = false;

% opts definitions for spectrum lasso_fista.
% =================================================
opts.lambda = lambda_val;
opts.tol = 1e-4;

% read data
% ================
fileID = fopen(fullfile(directory, 'ser'));
if fileID < 0
    error('process_Fista_spectrum:fileOpen', 'Cannot open the ''ser'' file in %s.', directory);
end
Data = fread(fileID, 'int32');  % read raw data from ser file.
fclose(fileID);
Data = Data(1:2:end) + 1j*Data(2:2:end);      % interleaved real/imag pairs.
FID = Data;
% size: npoints by acquired lines by Nshift.
Data = reshape(Data, npoints, length(Data)/(Nshift*npoints), Nshift);

% lines_to_skip between 0 to size(Data, 2) - 1.
% ======================================================
lines_to_skip = max(min(size(Data, 2) - 1, lines_to_skip), 0);
if size(Data, 2) > 1
    acquired_lines = size(Data, 2) - lines_to_skip;
else
    acquired_lines = 'full';
end

% Average the kept lines: size npoints by Nshift.
data = squeeze(mean(Data(:, lines_to_skip + 1:end, :), 2));

data = data(Nskip + 1:end, :);
data = data - mean(data)*remove_dc;       % remove dc.

data = flip(fftshift(data, 2), 2);
measured_data = data;

Nshift_vec = (0:Nshift - 1) - floor(Nshift/2);       % -Nshift/2 to Nshift/2 - 1.

% This section is meaningful only if there is a single spectral line.
% ======================================================================
if check_ssfp == true

    % Find the theoretical vs. actual steady-state signal.
    % Relevant only if there is a single line.
    % ===========================================================
    % Nshift shifted ssfp signals by 2*pi/Nshift radians.
    Phssfp_Nshift = Nshift_vec*2*pi/Nshift;
    ssfp_Nshift = ssfp_signal(T1, T2, TR, FA, Phssfp_Nshift, RF_phase);

    % shift phi further by the off-resonance*TR.
    ssfp_shift = angle(exp(1j*2*pi*offres/(1000/TR)))/(2*pi)*Nshift;
    ssfp_Nshift = circshift(ssfp_Nshift, -round(ssfp_shift));

    if average_ssfp == true
        opt_angle = zeros(1, size(measured_data, 1));
        ssfpf = zeros(size(measured_data));

        % ssfp signal, ssfp2, after constant phase correction by opt_angle.
        for jj = 1:size(measured_data, 1)
            [ssfpf(jj, :), opt_angle(jj)] = call_fmin_angle(measured_data(jj, :), ssfp_Nshift);
        end

        % calculate offres from the slope of opt_angle (degrees -> radians).
        offres_sim_Hz = mean(diff(unwrap(opt_angle*pi/180)))*sw/(2*pi);
        offres_sim_Hz = round(offres_sim_Hz);
        ssfp2 = mean(ssfpf);       % average over all data rows.
    else
        % ssfp signal, ssfp2, after constant phase correction by opt_angle.
        [ssfp2, opt_angle] = call_fmin_angle(measured_data(1, :), ssfp_Nshift);
    end

    % plot theoretical vs actual signal.
    % ==========================================
    h = figure;
    factor = 2.0;       % increase figure height by factor.
    set(h, 'Name', [' offres = ', num2str(offres, 4), ' Hz. skipped lines = ', num2str(lines_to_skip), '. ']);
    aa  = get(h, 'Position');
    aa(2) = aa(2) - (factor - 1)*aa(4) + 10;
    aa(4) = factor*aa(4);
    set(h, 'Position', aa);
    subplot(2, 1, 1);
    plot(Phssfp_Nshift, real(ssfp_Nshift)), hold on, plot(Phssfp_Nshift, imag(ssfp_Nshift), 'm');
    grid on;
    title([' calculated ssfp:  FA = ', num2str(FA, 3), ' deg. '])
    legend(' real ', ' imag ', 'Location', 'southeast')
    ax = gca;
    ax.XTick = pi*linspace(-1, 1, 5);
    ax.XMinorTick = 'on';
    ax.XTickLabel = {'-\pi', '-0.5\pi', '0', '0.5\pi', '\pi'};
    xlabel(' RF shift angle ')
    subplot(2, 1, 2);
    plot(Phssfp_Nshift, real(ssfp2)), hold on, plot(Phssfp_Nshift, imag(ssfp2), 'm');
    grid on;
    ax = gca;
    ax.XTick = pi*linspace(-1, 1, 5);
    ax.XMinorTick = 'on';
    ax.XTickLabel = {'-\pi', '-0.5\pi', '0', '0.5\pi', '\pi'};
    if average_ssfp == true
        title([' measured ssfp. offres\_sim\_Hz = ', num2str(offres_sim_Hz), ', offres = ', num2str(offres), '. ']);
    else
        title([' measured ssfp. opt angle = ', num2str(opt_angle, 3), ' deg. '])
    end
    legend(' real ', ' imag ', 'Location', 'southeast')
    xlabel(' RF shift angle ')
end

% find down sampling factor
% ==================================
Ratio = min(size(data, 1)/min_npoints, sw/min_sw_Hz);
% find data_rows (approximately size(data, 1)) and us_factor (approximately
% Ratio), such that data_rows is an integer and data_rows/us_factor is an
% integer.
[data_rows, us_factor] = find_N_R(size(data, 1), Ratio);

% limit sw_display_Hz to the down-sampled spectral width.
sw_display_Hz(1) = max(-sw/(2*us_factor), sw_display_Hz(1));
sw_display_Hz(2) = min(sw/(2*us_factor), sw_display_Hz(2));

if us_factor >= 1
    data = data(1:data_rows, :);
    fdata = ifftc(data);
    fdata1 = crop(fdata, [data_rows/us_factor, size(data, 2)]);
    data = fftc(fdata1);
end

data_orig = data;

Dt = 1/(sw*1e-3);       % in msec.
samp_points = size(data, 1);
dt = us_factor*Dt;
points_in_TR = round(TR/dt);
norm_time = (0:samp_points - 1)*dt/TR;       % sampling time in TR.

% number of frequency bins in the full range of 1/dwell_time = 1/dt.
freq_bins_full = points_in_TR*band_num;

freq_full = ((0:freq_bins_full - 1) - freq_bins_full/2)/(freq_bins_full*dt);       % full freq axis in kHz.

% frequency axis (kHz) between min(sw_display_Hz)*1e-3 to max(sw_display_Hz)*1e-3.
freq = freq_full(freq_full >= min(sw_display_Hz)*1e-3 & freq_full <= max(sw_display_Hz)*1e-3);
if isempty(freq)
    errordlg(' sw_display_Hz is not defined correctly. ')
    error('process_Fista_spectrum:badDisplayRange', 'sw_display_Hz is not defined correctly.');
end

Ph = freq*2*pi*TR;
freq_bins = length(Ph);

shifted_Ph = Ph(:) + Nshift_vec*2*pi/Nshift;       % size: freq_bins by Nshift.

% 1) km calculation.
% =======================
if run_km_calib == true && length(km) > 2
    % G and L for optimal km calculation (no extra phase offset).
    G = local_build_G(T1, T2, TR, FA, RF_phase, shifted_Ph, Ph, norm_time, 0);
    L = local_lipschitz(G, calc_GtG);

    min_cost = zeros(1, length(km));
    norm_Res = zeros(1, length(km));
    Xinit = zeros(freq_bins, 1);

    tStart = tic;
    wb = waitbar(0, ' running km calibration ... ');
    for j = 1:length(km)
        cexp_vec = exp(-2*pi*1i*norm_time(:)*km(j)/Nshift);
        m_data = reshape(data_orig.*cexp_vec, [], 1);

        % normalize data (L1).
        m_data = m_data/norm(m_data, 1);

        GtY = G'*m_data;

        [Res, min_cost(j), ~, ~, L, converged] = local_fista_with_retry(m_data, G, GtY, L, Xinit, opts_km, max_fista_tries);
        if ~converged
            close(wb);
            errordlg(' fista does not converge.')
            error('process_Fista_spectrum:noConvergence', 'lasso_fista did not converge during km calibration.');
        end

        norm_Res(j) = norm(Res, 1);
        Xinit = Res;       % warm start for the next km value.

        waitbar(j/length(km), wb);
    end
    close(wb);

    disp([' optimal km calculation time = ', num2str(toc(tStart)), ' sec. ']);

    % find optimal km from the minimum of the spline-interpolated L1 norm.
    kmint = linspace(km(1), km(end), 1000);
    y = interp1(km, norm_Res, kmint, 'spline');
    [~, idx_min] = min(y);       % first minimum, so km_opt is a scalar.
    km_opt = kmint(idx_min);
    figure; plot(km, norm_Res, 'o', kmint, y, 'm'); xlabel(' km '); grid;
    title([' km\_opt = ', num2str(km_opt, 3), '. '])
else
    km_opt = km_default;
end

% 2) spectrum calculation section
% ==========================================
cexp_vec = exp(-2*pi*1i*norm_time(:)*km_opt/Nshift);
m_data = reshape(data_orig.*cexp_vec, [], 1);

% normalize data (L1).
m_data = m_data/norm(m_data, 1);

% sub-bin phase offsets; their results are interleaved after the loop.
phase_vec = ((0:n_phases - 1) - floor(n_phases/2))/n_phases*mean(diff(Ph));

R1 = zeros(length(Ph), n_phases);
Xinit = zeros(freq_bins, 1);

tStart = tic;
show_wb = n_phases > 1;
if show_wb
    wb = waitbar(0, [' running Fista:  n\_phases = ', num2str(n_phases)], 'Name', [' points = ', num2str(samp_points), '. us factor = ', num2str(us_factor, 3)]);
end

for j = 1:n_phases
    G = local_build_G(T1, T2, TR, FA, RF_phase, shifted_Ph, Ph, norm_time, phase_vec(j));
    GtY = G'*m_data;

    % L is computed once; later phases reuse it (including any doubling).
    if j == 1
        L = local_lipschitz(G, calc_GtG);
    end

    [R1(:, j), min_cost, cost_vec, err, L, converged] = local_fista_with_retry(m_data, G, GtY, L, Xinit, opts, max_fista_tries);
    if ~converged
        if show_wb
            close(wb);
        end
        errordlg(' fista does not converge.')
        error('process_Fista_spectrum:noConvergence', 'lasso_fista did not converge during spectrum calculation.');
    end

    disp([' iterations = ', num2str(length(cost_vec)), '. '])
    disp([' min_cost = ', num2str(min_cost, 5), '. '])
    disp([' err = ', num2str(err, 4), '. '])
    if show_wb
        waitbar(j/n_phases, wb);
    end
end

disp([' t_fista = ', num2str(toc(tStart)), ' sec. ']);
if show_wb
    close(wb);
end

if n_phases > 1
    % Upsample each phase's magnitude by n_phases, shift by its phase index,
    % and average to interleave the sub-bin reconstructions.
    len = size(R1, 1);
    Res = zeros(len*n_phases, size(R1, 2));
    for j = 1:n_phases
        Res(1:end - n_phases + 1, j) = interp1(abs(R1(:, j)), 1:1/n_phases:len);
        Res(:, j) = circshift(Res(:, j), j - 1);
    end
    Tmp = mean(Res, 2);
    % set first n_phases - 1 points to first n_phases - 1 of Res(:, 1).
    % set last n_phases - 1 points to last n_phases - 1 of Res(:, n_phases).
    Tmp(1:n_phases - 1) = Res(1:n_phases - 1, 1);
    Tmp(end - n_phases + 2:end) = Res(end - n_phases + 2:end, n_phases);
    Res = Tmp;
else
    Res = R1;
end

freq1 = linspace(freq(1), freq(end), n_phases*length(freq));
lambda_max = norm(G'*m_data, inf);       % max possible lambda where X = 0. see paper 1) and l1_ls package.

ppm1_axis = (freq1*1000 + freq_read)/larmor_freq;
Res_flip = flip(Res); % flipping the spectrum so it correlates to the ppm axis

expt = NS*Nshift*TR/1000; % since TR is given in ms we convert to seconds
% NOTE(review): SNR_calc receives the unflipped Res together with ppm1_axis,
% while Res_flip is the orientation described as matching ppm1_axis -
% confirm which one SNR_calc expects.
[SNR_fista,SNRt_fista] = SNR_calc(Res,ppm1_axis,expt,signal_bound_ppm,noise_bound_ppm);

% Summary of the run (local only; not returned - useful when debugging).
params.directory = directory;
params.T1 = T1;
params.T2 = T2;
params.TR = TR;
params.down_samp_factor = us_factor;
params.min_sw_Hz = min_sw_Hz;
params.sw = sw;
params.down_sampled_sw = sw/us_factor;
params.bw_display_Hz = sw_display_Hz;
params.Nshift = Nshift;
params.band_num = band_num;
params.FA = FA;
params.optimal_km = km_opt;
params.points_in_TR = points_in_TR;
params.samp_points = samp_points;
params.lines_to_skip = lines_to_skip;
params.acquired_lines = acquired_lines;
params.Nskip = Nskip;
params.RF_phase = RF_phase;
params.n_phases = n_phases;
params.freq_res_Hz = (freq(end) - freq(1))*1e3/freq_bins;
params.check_ssfp = check_ssfp;
params.average_ssfp = average_ssfp;
params.lambda_max = lambda_max;       % l1_ls package.
params.lambda_fista = opts.lambda;
params.num_iter = length(cost_vec);
params.cost = min_cost;
params.err = err;
params.freq_read = freq_read;
params.larmor_freq = larmor_freq;


function val = local_get_field(s, name, default_val)
% Return s.(name) when struct s has that field, otherwise default_val.
if isfield(s, name)
    val = s.(name);
else
    val = default_val;
end


function G = local_build_G(T1, T2, TR, FA, RF_phase, shifted_Ph, Ph, norm_time, phase_offset)
% Build the forward matrix G. Column jj is the vectorized
% (numel(norm_time) by size(D, 2)) product of
% exp(-1j*norm_time*(Ph(jj) + phase_offset)) with row jj of the
% ssfp_signal response D evaluated at shifted_Ph + phase_offset.
freq_bins = numel(Ph);
D = reshape(ssfp_signal(T1, T2, TR, FA, shifted_Ph(:) + phase_offset, RF_phase), freq_bins, []);
A = exp(-1j*norm_time(:)*(Ph + phase_offset));       % numel(norm_time) by freq_bins.
G = zeros(numel(norm_time)*size(D, 2), freq_bins);
for jj = 1:freq_bins
    tmp = A(:, jj).*D(jj, :);
    G(:, jj) = tmp(:);
end


function L = local_lipschitz(G, calc_GtG)
% Step-size constant L passed to lasso_fista.
% calc_GtG > 0: L = (||G||op)^2 = max(eig(G'*G)). See link 1.
% otherwise:    L = ||G||_F (cheap). See link 2.
% NOTE(review): ||G||_F >= ||G||op, so ||G||_F^2 (not ||G||_F) is the
% guaranteed upper bound on (||G||op)^2; the unsquared value may
% underestimate L when ||G||op > 1. Non-convergence is compensated by
% doubling L in local_fista_with_retry - confirm against lasso_fista.
if calc_GtG > 0
    GtG = G'*G;
    L = max(eig(GtG));
else
    L = norm(G, 'fro');
end


function [X, cost, cost_vec, err, L, converged] = local_fista_with_retry(m_data, G, GtY, L, Xinit, opts, max_tries)
% Call lasso_fista up to max_tries times, doubling L after every attempt that
% returns err == -1. converged is true as soon as an attempt returns
% err ~= -1 (including on the last attempt). The possibly increased L is
% returned so callers can reuse it.
converged = false;
for attempt = 1:max_tries
    [X, cost, cost_vec, err] = lasso_fista(m_data, G, GtY, L, Xinit, opts);
    if err ~= -1
        converged = true;
        return
    end
    disp(' err = -1. ')
    L = 2*L;       % increase L to improve convergence.
end
